{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How-to\n",
    "\n",
    "1. You need to use [modeling.py](modeling.py) from extractive-summarization folder. An improvement BERT model to accept text longer than 512 tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset-bert.pkl', 'rb') as fopen:\n",
    "    dataset = pickle.load(fopen)\n",
    "dataset.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "import modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file=BERT_VOCAB, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 20\n",
    "batch_size = 8\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(dataset['train_texts']) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.float32, [None, None])\n",
    "        self.mask = tf.placeholder(tf.int32, [None, None])\n",
    "        self.clss = tf.placeholder(tf.int32, [None, None])\n",
    "        mask = tf.cast(self.mask, tf.float32)\n",
    "        \n",
    "        model = modeling.BertModel(\n",
    "            config=bert_config,\n",
    "            is_training=True,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.input_masks,\n",
    "            token_type_ids=self.segment_ids,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        outputs = tf.gather(model.get_sequence_output(), self.clss, axis = 1, batch_dims = 1)\n",
    "        self.logits = tf.layers.dense(outputs, 1)\n",
    "        self.logits = tf.squeeze(self.logits, axis=-1)\n",
    "        self.logits = self.logits * mask\n",
    "        crossent = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.Y)\n",
    "        crossent = crossent * mask\n",
    "        crossent = tf.reduce_sum(crossent)\n",
    "        total_size = tf.reduce_sum(mask)\n",
    "        self.cost = tf.div_no_nan(crossent, total_size)\n",
    "        \n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        \n",
    "        l = tf.round(tf.sigmoid(self.logits))\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(tf.boolean_mask(l, tf.equal(self.Y, 1)), tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(learning_rate = 1e-5)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = dataset['train_texts']\n",
    "test_X = dataset['test_texts']\n",
    "train_clss = dataset['train_clss']\n",
    "test_clss = dataset['test_clss']\n",
    "train_Y = dataset['train_labels']\n",
    "test_Y = dataset['test_labels']\n",
    "train_segments = dataset['train_segments']\n",
    "test_segments = dataset['test_segments']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x, _ = pad_sentence_batch(train_X[i : index], 0)\n",
    "        batch_y, _ = pad_sentence_batch(train_Y[i : index], 0)\n",
    "        batch_segments, _ = pad_sentence_batch(train_segments[i : index], 0)\n",
    "        batch_clss, _ = pad_sentence_batch(train_clss[i : index], -1)\n",
    "        batch_clss = np.array(batch_clss)\n",
    "        batch_x = np.array(batch_x)\n",
    "        batch_mask = 1 - (batch_clss == -1)\n",
    "        batch_clss[batch_clss == -1] = 0\n",
    "        mask_src = 1 - (batch_x == 0)\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.mask: batch_mask,\n",
    "                model.clss: batch_clss,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: mask_src}\n",
    "        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
    "                                    feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x, _ = pad_sentence_batch(test_X[i : index], 0)\n",
    "        batch_y, _ = pad_sentence_batch(test_Y[i : index], 0)\n",
    "        batch_segments, _ = pad_sentence_batch(test_segments[i : index], 0)\n",
    "        batch_clss, _ = pad_sentence_batch(test_clss[i : index], -1)\n",
    "        batch_clss = np.array(batch_clss)\n",
    "        batch_x = np.array(batch_x)\n",
    "        batch_mask = 1 - (batch_clss == -1)\n",
    "        batch_clss[batch_clss == -1] = 0\n",
    "        mask_src = 1 - (batch_x == 0)\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.mask: batch_mask,\n",
    "                model.clss: batch_clss,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: mask_src}\n",
    "        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
    "                                    feed_dict = feed)\n",
    "\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
